{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dpath = \"./data/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>msno</th>\n",
       "      <th>song_id</th>\n",
       "      <th>source_system_tab</th>\n",
       "      <th>source_screen_name</th>\n",
       "      <th>source_type</th>\n",
       "      <th>city</th>\n",
       "      <th>bd</th>\n",
       "      <th>gender</th>\n",
       "      <th>registered_via</th>\n",
       "      <th>...</th>\n",
       "      <th>expiration_date</th>\n",
       "      <th>song_length</th>\n",
       "      <th>genre_ids</th>\n",
       "      <th>language</th>\n",
       "      <th>mult_genre</th>\n",
       "      <th>user_pop</th>\n",
       "      <th>item_pop</th>\n",
       "      <th>user_rate</th>\n",
       "      <th>item_rate</th>\n",
       "      <th>lfm_reco</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=</td>\n",
       "      <td>WmHKgKMlp1lQMecNdNvDMkvIycZYHnFwDT72I5sIssc=</td>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "      <td>...</td>\n",
       "      <td>13</td>\n",
       "      <td>-0.14208</td>\n",
       "      <td>24</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.401</td>\n",
       "      <td>2.944</td>\n",
       "      <td>0.366</td>\n",
       "      <td>0.504</td>\n",
       "      <td>0.38083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=</td>\n",
       "      <td>y/rsZ9DC7FwK5F2PK2D5mj+aOBUJAjuu3dZ14NgE0vM=</td>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "      <td>...</td>\n",
       "      <td>13</td>\n",
       "      <td>0.45662</td>\n",
       "      <td>25</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.401</td>\n",
       "      <td>32.878</td>\n",
       "      <td>0.366</td>\n",
       "      <td>0.625</td>\n",
       "      <td>0.53185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>/uQAlrAkaczV+nWCd2sPF2ekvXPRipV7q0l+gbLuxjw=</td>\n",
       "      <td>8eZLFOdGVdXBSqoAv5nsLigeH2BvKXzTQYtUM53I0k4=</td>\n",
       "      <td>0</td>\n",
       "      <td>16</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>...</td>\n",
       "      <td>12</td>\n",
       "      <td>0.42822</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.594</td>\n",
       "      <td>-0.072</td>\n",
       "      <td>0.144</td>\n",
       "      <td>0.400</td>\n",
       "      <td>-0.88885</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=</td>\n",
       "      <td>ztCf8thYsS4YN3GcIL/bvoxLm/T5mYBVKOO4C9NiVfQ=</td>\n",
       "      <td>7</td>\n",
       "      <td>11</td>\n",
       "      <td>7</td>\n",
       "      <td>3</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>9</td>\n",
       "      <td>...</td>\n",
       "      <td>13</td>\n",
       "      <td>0.23750</td>\n",
       "      <td>25</td>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.034</td>\n",
       "      <td>-0.029</td>\n",
       "      <td>0.296</td>\n",
       "      <td>0.226</td>\n",
       "      <td>-0.18739</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=</td>\n",
       "      <td>MKVMpslKcQhMaFEgcEQhEfi5+RZhMYlU3eRDpySrH8Y=</td>\n",
       "      <td>7</td>\n",
       "      <td>11</td>\n",
       "      <td>7</td>\n",
       "      <td>3</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>9</td>\n",
       "      <td>...</td>\n",
       "      <td>13</td>\n",
       "      <td>-0.30702</td>\n",
       "      <td>32</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.034</td>\n",
       "      <td>-0.072</td>\n",
       "      <td>0.296</td>\n",
       "      <td>0.400</td>\n",
       "      <td>0.98938</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                          msno  \\\n",
       "0   0  V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=   \n",
       "1   1  V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=   \n",
       "2   2  /uQAlrAkaczV+nWCd2sPF2ekvXPRipV7q0l+gbLuxjw=   \n",
       "3   3  1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=   \n",
       "4   4  1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=   \n",
       "\n",
       "                                        song_id  source_system_tab  \\\n",
       "0  WmHKgKMlp1lQMecNdNvDMkvIycZYHnFwDT72I5sIssc=                  3   \n",
       "1  y/rsZ9DC7FwK5F2PK2D5mj+aOBUJAjuu3dZ14NgE0vM=                  3   \n",
       "2  8eZLFOdGVdXBSqoAv5nsLigeH2BvKXzTQYtUM53I0k4=                  0   \n",
       "3  ztCf8thYsS4YN3GcIL/bvoxLm/T5mYBVKOO4C9NiVfQ=                  7   \n",
       "4  MKVMpslKcQhMaFEgcEQhEfi5+RZhMYlU3eRDpySrH8Y=                  7   \n",
       "\n",
       "   source_screen_name  source_type  city  bd  gender  registered_via  \\\n",
       "0                   7            2     1   0       2               7   \n",
       "1                   7            2     1   0       2               7   \n",
       "2                  16            9     1   0       2               4   \n",
       "3                  11            7     3   5       1               9   \n",
       "4                  11            7     3   5       1               9   \n",
       "\n",
       "     ...     expiration_date  song_length  genre_ids  language  mult_genre  \\\n",
       "0    ...                  13     -0.14208         24         4           0   \n",
       "1    ...                  13      0.45662         25         4           0   \n",
       "2    ...                  12      0.42822         12         2           0   \n",
       "3    ...                  13      0.23750         25         8           0   \n",
       "4    ...                  13     -0.30702         32         0           0   \n",
       "\n",
       "   user_pop  item_pop  user_rate  item_rate  lfm_reco  \n",
       "0    -0.401     2.944      0.366      0.504   0.38083  \n",
       "1    -0.401    32.878      0.366      0.625   0.53185  \n",
       "2    -0.594    -0.072      0.144      0.400  -0.88885  \n",
       "3    -0.034    -0.029      0.296      0.226  -0.18739  \n",
       "4    -0.034    -0.072      0.296      0.400   0.98938  \n",
       "\n",
       "[5 rows x 21 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_final = pd.read_csv(dpath+\"test_final.csv\")\n",
    "test_final.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2556790, 21)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_final.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#columns和header参数的区别\n",
    "test_id = test_final[\"id\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.array(test_final.iloc[:,3:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 29.6 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "lgbm = pickle.load(open(\"./model/lgbm.pkl\",\"rb\"))\n",
    "y_proba = lgbm.predict_proba(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_proba[:20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.提交结果  \n",
    "kaggle最终auc得分为0.60082"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def submit(submission_file):\n",
    "    f = open(submission_file,\"w+\")\n",
    "    ocolnames = [\"id\",\"target\"]\n",
    "    f.write(\",\".join(ocolnames)+\"\\n\")\n",
    "    outcols = []\n",
    "    for i in range(len(test_id)):\n",
    "        col1 = test_id[i]\n",
    "        col2 = y_proba[:,1][i]\n",
    "        outcols = [str(col1),str(col2)]\n",
    "        f.write(\",\".join(outcols)+\"\\n\")\n",
    "    f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "submit(\"./submission/submission_file.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "submission_file = pd.read_csv(\"./submission/submission_file.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.398106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.438361</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.122106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.089994</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.118133</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>0.124251</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>0.124989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>0.453406</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>0.188738</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>0.803759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>10</td>\n",
       "      <td>0.801334</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>11</td>\n",
       "      <td>0.279935</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>12</td>\n",
       "      <td>0.247023</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>13</td>\n",
       "      <td>0.681488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>14</td>\n",
       "      <td>0.258585</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>15</td>\n",
       "      <td>0.629014</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>16</td>\n",
       "      <td>0.627083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>17</td>\n",
       "      <td>0.533829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>18</td>\n",
       "      <td>0.394208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>19</td>\n",
       "      <td>0.471810</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    id    target\n",
       "0    0  0.398106\n",
       "1    1  0.438361\n",
       "2    2  0.122106\n",
       "3    3  0.089994\n",
       "4    4  0.118133\n",
       "5    5  0.124251\n",
       "6    6  0.124989\n",
       "7    7  0.453406\n",
       "8    8  0.188738\n",
       "9    9  0.803759\n",
       "10  10  0.801334\n",
       "11  11  0.279935\n",
       "12  12  0.247023\n",
       "13  13  0.681488\n",
       "14  14  0.258585\n",
       "15  15  0.629014\n",
       "16  16  0.627083\n",
       "17  17  0.533829\n",
       "18  18  0.394208\n",
       "19  19  0.471810"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "submission_file.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2556790, 2)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "submission_file.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
